import time

import tensorrt as trt
import cv2

from yolo import Yolo

def _run_predict(yolo, print_info, show_type):
    """Interactively run detection on single images until the user types "quit".

    Args:
        yolo: detector object; this function calls its ``predict`` and
            ``show_result`` methods.
        print_info: forwarded to ``yolo.show_result``.
        show_type: forwarded to ``yolo.show_result`` ("cv2" or "PIL").
    """
    while True:
        img_path = input("Input file path:")

        if img_path == "quit":
            print("Stop prediction")
            break

        # cv2.imread does not raise on a missing/unreadable file; it returns None,
        # so the result must be checked explicitly.
        img = cv2.imread(img_path)
        if img is None:
            print("Open error, please check your image path.")
            continue

        start = time.perf_counter()
        result = yolo.predict(img)
        elapsed = time.perf_counter() - start
        print("elapsed time: {:.5f}".format(elapsed))

        yolo.show_result(img, result, print_info=print_info, show_type=show_type)


def _run_video(yolo, video_path, video_save_path, video_fps, print_info):
    """Run detection on every frame of a video file or camera stream.

    Annotated frames are shown in a window; press "q" to stop early.

    Args:
        yolo: detector object; this function calls its ``predict`` and
            ``draw_bboxes`` methods.
        video_path: anything accepted by ``cv2.VideoCapture`` (0 = default camera).
        video_save_path: output file for the annotated video; empty/None disables saving.
        video_fps: frame rate written into the saved video.
        print_info: forwarded to ``yolo.draw_bboxes``.

    Raises:
        ValueError: if the first frame cannot be read from ``video_path``.
    """
    cap = cv2.VideoCapture(video_path)
    out = None
    try:
        ref, frame = cap.read()
        if not ref:
            raise ValueError("Open error, please check your video path or camera.")

        if video_save_path:
            fourcc = cv2.VideoWriter_fourcc(*'XVID')
            # Use the real frame size: CAP_PROP_FRAME_* may report 0 for some
            # sources, and VideoWriter silently drops frames of a different size.
            size = (frame.shape[1], frame.shape[0])
            out = cv2.VideoWriter(video_save_path, fourcc, video_fps, size)

        fps = 0.
        prev = time.perf_counter()
        # The probe frame read above is processed too, instead of being discarded.
        while ref:
            result = yolo.predict(frame)
            frame = yolo.draw_bboxes(frame, result, print_info=print_info)

            # Smoothed FPS over the full loop period; guard against a zero delta
            # from a coarse clock.
            now = time.perf_counter()
            delta = now - prev
            prev = now
            if delta > 0:
                fps = (fps + (1. / delta)) / 2
            print("fps= %.2f" % (fps))
            frame = cv2.putText(frame, "fps= %.2f" % (fps), (0, 40), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)

            cv2.imshow("video", frame)
            if out is not None:
                out.write(frame)

            if cv2.waitKey(1) & 0xff == ord("q"):
                break

            ref, frame = cap.read()
    finally:
        # Always release native handles, even if detection raised.
        cap.release()
        if out is not None:
            out.release()
        cv2.destroyAllWindows()

    print("Video Detection Done!")
    if video_save_path:
        print("Save processed video to the path :" + video_save_path)


if __name__ == '__main__':
    #---------------------------------------------------------#
    #    mode selects what to run:
    #       "predict"      predict on single images (interactive)
    #       "video"        predict on a video file or a camera stream
    #---------------------------------------------------------#
    mode = "video"

    #---------------------------------------------------------#
    #    "show_type"       display backend, either "cv2" or "PIL"
    #    "print_info"      whether to print the prediction results
    #---------------------------------------------------------#
    show_type = "PIL"
    print_info = True

    # ---------------------------------------------------------#
    #    "video_path"      path of the video; 0 means the default camera
    #    "video_save_path" where to save the output video; empty means don't save
    #    "video_fps"       frame rate of the saved video
    # ---------------------------------------------------------#
    video_path = 0
    video_save_path = ""
    video_fps = 25.

    # Validate the mode before loading the (expensive) TensorRT engine.
    if mode not in ("predict", "video"):
        raise ValueError("Please specify the correct mode: 'predict', 'video'.")

    yolo = Yolo("engine/yolox_m.engine", "config/yolox.yaml", trt.Logger.ERROR)

    if mode == "predict":
        _run_predict(yolo, print_info, show_type)
    else:
        _run_video(yolo, video_path, video_save_path, video_fps, print_info)
